library(tidyverse)
library(maptools)
library(igraph)
library(parallel)
library(redist)
library(spdep)
library(gridExtra)

## Importance-resample soft-constraint MCMC output.
##
## Keeps the draws taken at temperature `beta` whose distance from population
## parity is at most `pop`. It then resamples them with replacement (same
## count) using weights proportional to 1 / exp(beta * constraint_pop), to undo
## the tilting introduced by the population constraint.
##
## x:    list with per-draw vectors `distance_parity`, `beta_sequence`,
##       `constraint_pop` and a `partitions` matrix with one column per draw
##       (as returned by redist.mcmc() in this script).
## beta: temperature whose draws are kept and reweighted.
## pop:  maximum allowed distance from population parity.
##
## Returns the resampled columns of `x$partitions`, always as a matrix (even
## when only one draw qualifies). Stops with an error if no draw qualifies.
ipw <- function(x, beta, pop){
    inds <- intersect(which(x$distance_parity <= pop),
                      which(x$beta_sequence == beta))
    if (length(inds) == 0) {
        stop("ipw(): no draws with beta = ", beta,
             " and distance_parity <= ", pop, call. = FALSE)
    }
    psi <- x$constraint_pop[inds]
    ## 1 / exp(beta * psi) == exp(-beta * psi). Work on the log scale and shift
    ## by the maximum so large |beta * psi| cannot overflow to Inf or underflow
    ## to 0; sample.int() normalizes `prob`, so the shift does not change the
    ## resampling distribution.
    log_w <- -beta * psi
    w <- exp(log_w - max(log_w))
    ## Index through sample.int(): sample(inds, ...) would sample from
    ## 1:inds when exactly one draw qualifies.
    draws <- inds[sample.int(length(inds), length(inds),
                             replace = TRUE, prob = w)]
    x$partitions[, draws, drop = FALSE]
}
## House ggplot2 theme: theme_classic() plus a black panel border, no grid
## lines, background or strip background, a title-less legend at the bottom,
## and centred plot titles and subtitles.
ben_theme <- function(){
    house_style <- theme(
        axis.line = element_line(colour = "black"),
        panel.background = element_blank(),
        panel.border = element_rect(colour = "black", fill = NA, size = 1),
        panel.grid.major = element_blank(),
        panel.grid.minor = element_blank(),
        strip.background = element_blank(),
        legend.position = "bottom",
        legend.title = element_blank(),
        plot.title = element_text(hjust = 0.5),
        plot.subtitle = element_text(hjust = .5)
    )
    theme_classic() + house_style
}

## Full Florida precinct shapefile
fl_map <- readShapePoly("../data/fl/FL.shp")

## Load nodes: the row names of fl_map@data that make up validation map `mn`,
## read from the first line of the node file (space-separated)
mn <- 89
nodes <- strsplit(
    readLines(paste0("../data/largemap_enum_all/map_sub_nodes_", mn, ".dat")),
    " "
)[[1]]

## First, subset down to right map
fl_sub <- fl_map[which(rownames(fl_map@data) %in% nodes),]

## Then, reorder rows into the order listed in the node file - otherwise, they
## won't line up properly with the original map
fl_sub <- fl_sub[match(nodes, rownames(fl_sub@data)),]

## Convert factor columns to numeric, parsing their labels rather than their
## integer codes (this particular data is messy).
## NOTE(review): any factor whose levels are not numbers becomes NA (with a
## warning) - confirm every factor column here really holds numeric data.
indx <- vapply(fl_sub@data, is.factor, logical(1))
fl_sub@data[indx] <- lapply(
    fl_sub@data[indx], function(x) as.numeric(as.character(x))
)

## ---------------------------
## Reorder edges and enumerate
## ---------------------------
## Run a shell command and stop if it exits with a non-zero status, so a failed
## step cannot leave stale or empty files behind for the steps that read them.
run_or_stop <- function(cmd){
    status <- system(cmd)
    if (status != 0) {
        stop("Command exited with status ", status, ": ", cmd, call. = FALSE)
    }
    invisible(status)
}

## Reorder the map's edge list with ndscut.py
run_or_stop("python ndscut.py <../data/largemap_enum_all/map_sub_89.dat >../data/largemap_enum_all/map_sub_ordered_89.dat")

## Enumerate all solutions with k = 2 on the reordered edge list
run_or_stop("../enumpart_private/enumpart ../data/largemap_enum_all/map_sub_ordered_89.dat -k 2 -allsols > ../data/largemap_enum_all/solutions_89.dat")

## -------------
## Get solutions
## -------------
## Each line of solutions_89.dat holds one enumerated plan as space-separated
## row indices into the ordered edge list. The file is streamed in chunks
## through a single open connection (re-reading it with `skip` for every chunk
## would rescan all earlier lines each time). Each plan is converted to a
## vertex-to-district assignment, and its Republican dissimilarity index and
## maximum deviation from population parity are stored.
n_sols <- 44082156 ## number of solution lines in solutions_89.dat
chunk_size <- 100000
ndists <- 2
n_nodes <- nrow(fl_sub@data)
n_cores <- max(1L, detectCores() - 1L, na.rm = TRUE)

## Stop if any mclapply() task failed: mclapply() puts a "try-error" (or NULL,
## if a worker produced no result) in place of the value instead of raising.
check_mc <- function(res){
    bad <- vapply(res, function(r) is.null(r) || inherits(r, "try-error"),
                  logical(1))
    if (any(bad)) {
        stop(sum(bad), " parallel task(s) failed or returned no result",
             call. = FALSE)
    }
    res
}

## Chunk i covers solution lines (lines_to_read[i - 1], lines_to_read[i]]
lines_to_read <- unique(c(seq(from = 0, to = n_sols, by = chunk_size), n_sols))
seg_out <- rep(NA_real_, n_sols)
parity_out <- rep(NA_real_, n_sols)

## Edge list of the ordered map (one edge per row), read once
el_chr <- do.call(rbind, strsplit(
    readLines("../data/largemap_enum_all/map_sub_ordered_89.dat"), split = " "
))
el <- matrix(as.numeric(el_chr), nrow = nrow(el_chr))

## Map-wide totals used by the dissimilarity index
tot_pop <- sum(fl_sub@data$pop)
rep_share <- sum(fl_sub@data$mccain) / tot_pop

sol_con <- file("../data/largemap_enum_all/solutions_89.dat", open = "r")
tryCatch({
    for(i in seq_along(lines_to_read)[-1]){
        chunk_rows <- (lines_to_read[i - 1] + 1):lines_to_read[i]

        ## Load the next chunk of solutions
        sols_all <- readLines(sol_con, n = length(chunk_rows))
        if(length(sols_all) != length(chunk_rows)){
            stop("Expected ", length(chunk_rows), " solutions ending at line ",
                 format(lines_to_read[i], scientific = FALSE), " but read ",
                 length(sols_all), call. = FALSE)
        }

        ## All solutions to matrix: one column per plan, one row per vertex
        sols_out_all <- check_mc(mclapply(seq_along(sols_all), function(x){
            sol_split <- as.numeric(strsplit(sols_all[x], split = " ")[[1]])
            el_sub <- el[sol_split, , drop = FALSE]
            comps_out <- components(graph_from_edgelist(el_sub, directed = FALSE))
            membership <- comps_out$membership
            ## Fewer components than districts: append the missing district
            ## labels. NOTE(review): this assumes the only vertices absent from
            ## the graph are trailing ones not touched by any kept edge.
            if(comps_out$no < ndists){
                membership <- c(membership, (max(membership) + 1):ndists)
            }
            membership
        }, mc.cores = n_cores))
        ## Guard against cbind() silently recycling short assignments
        if(any(lengths(sols_out_all) != n_nodes)){
            stop("A solution did not assign all ", n_nodes, " vertices",
                 call. = FALSE)
        }
        sols_out_all <- do.call(cbind, sols_out_all)

        ## Dissimilarity index of McCain votes across districts
        seg_truth_sample <- unlist(check_mc(mclapply(seq_len(ncol(sols_out_all)), function(x){
            t_i <- tapply(fl_sub@data$pop, sols_out_all[, x], sum)
            p_i <- tapply(fl_sub@data$mccain, sols_out_all[, x], sum) / t_i
            sum(t_i * abs(p_i - rep_share) /
                    (2 * tot_pop * rep_share * (1 - rep_share)), na.rm = TRUE)
        }, mc.cores = n_cores)))

        ## Largest relative deviation of a district population from parity
        pop_dist_sample <- unlist(check_mc(mclapply(seq_len(ncol(sols_out_all)), function(x){
            distpop <- tapply(fl_sub@data$pop, sols_out_all[, x], sum)
            parpop <- sum(distpop) / length(distpop)
            max(abs(distpop / parpop - 1))
        }, mc.cores = n_cores)))

        ## Store
        seg_out[chunk_rows] <- seg_truth_sample
        parity_out[chunk_rows] <- pop_dist_sample

        cat(paste0("Done with ", format(lines_to_read[i], scientific = FALSE),
                   " solutions out of ", format(n_sols, scientific = FALSE),
                   " at ", Sys.time(), ".\n"))
    }
}, finally = close(sol_con))
out <- data.frame(seg = seg_out, parity = parity_out)
## Output file passed positionally: readr names this argument `path` in older
## releases and `file` in newer ones
write_csv(out, "../data/largemap_enum_all/segregation_89.csv")

## ---------------
## Run simulations
## ---------------
seg_truth <- read_csv("../data/largemap_enum_all/segregation_89.csv")
pop <- fl_sub@data$pop
## McCain vote counts per precinct. This binds a data object named `rep`;
## calls such as rep("x", 3) still reach base::rep because R skips
## non-function bindings when looking up a name in call position.
rep <- fl_sub@data$mccain
nsims <- 50000
nchains <- 4
## Under the default Mersenne-Twister generator, mclapply() re-seeds each
## forked worker from scratch (see ?mcparallel, "Random numbers"), so
## set.seed() alone does not make the chains reproducible. L'Ecuyer-CMRG gives
## each worker its own stream derived from the seed below.
RNGkind("L'Ecuyer-CMRG")
set.seed(38637) ## Random.org

## Run `nchains` chains in parallel. `sim_fun` runs one chain and returns its
## segregation values; the values are concatenated across chains. A failed
## chain stops the script instead of being unlisted into the results.
sim_cores <- max(1L, detectCores(), na.rm = TRUE)
run_chains <- function(sim_fun){
    res <- mclapply(seq_len(nchains), function(chain) sim_fun(),
                    mc.cores = sim_cores)
    bad <- vapply(res, function(r) is.null(r) || inherits(r, "try-error"),
                  logical(1))
    if (any(bad)) {
        stop(sum(bad), " of ", nchains, " chains failed or returned no result",
             call. = FALSE)
    }
    unlist(res)
}

## Unconstrained
seg_unc <- run_chains(function(){
    out_unc <- redist.mcmc(fl_sub, popvec = pop,
                           nsims = nsims, ndists = 2)
    redist.segcalc(out_unc, grouppop = rep, fullpop = pop)
})

## 5%: soft constraint, reweighted with ipw()
seg_05 <- run_chains(function(){
    out_05 <- redist.mcmc(fl_sub, popvec = pop,
                          nsims = nsims, ndists = 2,
                          constraint = "population", beta = -10)
    out_05 <- ipw(out_05, beta = -10, pop = .05)
    redist.segcalc(out_05, grouppop = rep, fullpop = pop)
})

## 5%: hard constraint via popcons
seg_05_hard <- run_chains(function(){
    out_05_hard <- redist.mcmc(fl_sub, popvec = pop,
                               nsims = nsims, ndists = 2,
                               popcons = .05)
    redist.segcalc(out_05_hard, grouppop = rep, fullpop = pop)
})

## 1%: soft constraint, reweighted with ipw()
seg_01 <- run_chains(function(){
    out_01 <- redist.mcmc(fl_sub, popvec = pop,
                          nsims = nsims, ndists = 2,
                          constraint = "population", beta = -50)
    out_01 <- ipw(out_01, beta = -50, pop = .01)
    redist.segcalc(out_01, grouppop = rep, fullpop = pop)
})

## 1%: hard constraint via popcons
seg_01_hard <- run_chains(function(){
    out_01_hard <- redist.mcmc(fl_sub, popvec = pop,
                               nsims = nsims, ndists = 2,
                               popcons = .01)
    redist.segcalc(out_01_hard, grouppop = rep, fullpop = pop)
})

## Load RSG sims (provides `sims_out`; elements 1-3 of each entry are the
## unconstrained, 5% and 1% plans, as used below)
load("../data/largemap_enum_all/simulations_rsg.RData")

rsg_unc <- do.call("cbind", lapply(sims_out, "[[", 1))
seg_unc_rsg <- redist.segcalc(rsg_unc, grouppop = rep, fullpop = pop)

rsg_05 <- do.call("cbind", lapply(sims_out, "[[", 2))
seg_05_rsg <- redist.segcalc(rsg_05, grouppop = rep, fullpop = pop)

rsg_01 <- do.call("cbind", lapply(sims_out, "[[", 3))
seg_01_rsg <- redist.segcalc(rsg_01, grouppop = rep, fullpop = pop)

## -----------
## Create plot
## -----------
## Long-format data: one row per plan, tagged with its sampler and with the
## population constraint it was drawn under. The constrained panels show the
## soft-constraint (ipw-reweighted) MCMC draws; the hard-constraint runs
## (seg_05_hard, seg_01_hard) are not plotted.
parity_levels <- c("No Equal Population Constraint",
                   "5% Equal Population Constraint",
                   "1% Equal Population Constraint")
seg_group <- function(seg, distribution, pop_parity){
    data.frame(seg = seg,
               distribution = rep_len(distribution, length(seg)),
               pop_parity = rep_len(pop_parity, length(seg)),
               stringsAsFactors = FALSE)
}
df_plot <- bind_rows(
    seg_group(seg_truth$seg, "Truth", parity_levels[1]),
    seg_group(seg_unc, "MCMC", parity_levels[1]),
    seg_group(seg_unc_rsg, "RSG", parity_levels[1]),
    seg_group(seg_truth$seg[seg_truth$parity <= .05], "Truth", parity_levels[2]),
    seg_group(seg_05, "MCMC", parity_levels[2]),
    seg_group(seg_05_rsg, "RSG", parity_levels[2]),
    seg_group(seg_truth$seg[seg_truth$parity <= .01], "Truth", parity_levels[3]),
    seg_group(seg_01, "MCMC", parity_levels[3]),
    seg_group(seg_01_rsg, "RSG", parity_levels[3])
) %>%
    mutate(pop_parity = factor(pop_parity, levels = parity_levels),
           distribution = factor(distribution,
                                 levels = c("Truth", "MCMC", "RSG")))

validation_plot <- ggplot(df_plot, aes(seg, colour = distribution, fill = distribution,
                                       alpha = distribution, lty = distribution)) +
    geom_density(lwd = 1.1) +
    facet_grid(~ pop_parity) +
    theme_classic() +
    scale_colour_manual(values = c("grey", "black", "red")) +
    scale_fill_manual(values = c("grey", "white", "white")) +
    scale_linetype_manual(values = c(1, 1, 2)) +
    scale_alpha_manual(values = c(1, 0, 0)) +
    labs(x = "Republican Dissimilarity Index", y = "Density") +
    ben_theme() +
    theme(legend.position = c(.22, .7))
    ## guides(colour = guide_legend(nrow = 1, byrow = TRUE))
print(validation_plot)
## Save the plot object explicitly; adding ggsave() with `+` relies on
## last_plot() and is rejected by current ggplot2
ggsave("../paper/figs/largemap_enum_all_validation.pdf", plot = validation_plot,
       height = 3, width = 8)

## -------------------------------
## Descriptive plot of the new map
## -------------------------------
## Join precinct attributes onto the polygon outlines for plotting
fl_sub@data$id <- rownames(fl_sub@data)
fl_sub_points <- fortify(fl_sub, region = "id")
fl_sub_df <- plyr::join(fl_sub_points, fl_sub@data, by = "id")

## Map plot: precincts shaded by population
map_plot <- ggplot(fl_sub_df) +
    aes(long, lat, group = group, fill = pop) +
    geom_polygon(colour = "black", size = .3) +
    geom_path(color="black", lwd = .05) +
    theme_classic() +
    scale_fill_gradient(low = "white", high = "black") +
    coord_equal() +
    guides(fill = guide_legend(title = "Population")) +
    labs(x = "Longitude", y = "Latitude",
         title = "70-Precinct Validation Map") +
    theme(panel.border = element_rect(colour = "black", fill=NA),
          plot.title = element_text(hjust = 0.5),
          legend.position = c(.11, .15),
          legend.title = element_text(size = 8),
          legend.text = element_text(size = 8))

## Parity plot: number of enumerated plans in each 0.01-wide bin of distance
## from population parity, for plans within max_parity. Bin k covers
## ((k - 1) / 100, k / 100]; pg is its upper edge, and bars are drawn at the
## bin centre (pg - .005).
max_parity <- .2
parity_plot_df <- seg_truth %>% filter(parity <= max_parity) %>%
    mutate(
        ## Fixed breaks up to max_parity: breaks ending at max(parity) stop
        ## short of it whenever it is not a multiple of 0.01, which would send
        ## the largest parities to NA.
        parity_group = cut(parity, breaks = seq(0, max_parity, .01),
                           include.lowest = TRUE, labels = FALSE)
    ) %>%
    group_by(pg = parity_group / 100) %>%
    summarize(n = n())
numsols <- c(
    sum(seg_truth$parity <= .01),
    sum(seg_truth$parity <= .05),
    sum(seg_truth$parity <= .1)
)
text_out <- data.frame(text = paste0(paste0("< ", format(c(.01, .05, .10), nsmall = 2), " Parity: ", numsols), collapse = "\n"))

parity_plot <- ggplot(parity_plot_df, aes(pg - .005, n)) +
    geom_bar(stat = "identity") +
    theme_classic() +
    scale_y_continuous(labels = scales::comma) +
    geom_text(data = text_out, aes(x = .1, y = Inf, label = text),
              vjust = 1.2, hjust = 1.1) +
    labs(x = "Distance from Population Parity",
         y = "Number of Plans",
         title = "Distribution of Population Parity on 70-Precinct Map") +
    theme(panel.border = element_rect(colour = "black", fill=NA),
          plot.title = element_text(hjust = 0.5),
          legend.position = c(.05, .05))

## Write both panels side by side; close the device even if drawing fails
pdf(file = "../paper/figs/map_descriptives_70precinct.pdf", height = 5, width = 10)
tryCatch(
    grid.arrange(map_plot, parity_plot, widths = c(2, 2), nrow = 1),
    finally = dev.off()
)
